package com.bjy.qa.agent.tester;

import com.bjy.qa.agent.enumtype.CatalogType;
import com.bjy.qa.agent.exception.MyException;
import com.bjy.qa.agent.tester.handler.perf.PerfTesterTaskBootThread;
import com.bjy.qa.agent.tester.handler.tester.TesterTaskBootThread;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.util.CollectionUtils;

import java.util.Collections;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import java.util.stream.Collectors;

import static com.bjy.qa.agent.tester.SuiteListener.runningTestsMap;

/**
 * 任务管理器
 */
public class TaskManager {
    private static final Logger logger = LoggerFactory.getLogger(TaskManager.class);

    /** Prefix of the rid segment in boot thread names (the same form {@link #forceStopSuite} removes). */
    private static final String RID_PREFIX = "rid";
    /** Index of the rid segment after splitting a boot thread name on {@code '-'}. */
    private static final int RID_SEGMENT_INDEX = 4;

    // key: boot thread name; value: the boot thread itself
    private static final ConcurrentHashMap<String, Thread> bootThreadsMap = new ConcurrentHashMap<>();
    // key: boot thread name; value: child threads registered under that boot thread.
    // Values are concurrent sets so they can be iterated while other threads register children.
    // NOTE(review): marking children as daemon threads does NOT stop them when their boot thread
    // ends; it only means they don't keep the JVM alive. They must be interrupted or flagged
    // explicitly (see clearTerminatedThreadByKey / forceStopSuite / forceStopDebugStepThread).
    private static final ConcurrentHashMap<String, Set<Thread>> childThreadsMap = new ConcurrentHashMap<>();
    // rid segments of boot threads currently running (presumably of the form "rid<resultId>")
    private static final Set<String> runningRidSet = Collections.synchronizedSet(new HashSet<>());

    /**
     * Returns whether a boot thread for the given result id (rid) is currently registered as running.
     * The set stores the rid segment of the boot thread name ({@code "rid" + id}), so the id is
     * prefixed before the lookup.
     * @param rid result id; {@code null} is never running
     * @return {@code true} if a running boot thread was registered for this rid
     */
    public static boolean ridRunning(Integer rid) {
        if (rid == null) {
            return false;
        }
        return runningRidSet.contains(RID_PREFIX + rid);
    }

    /**
     * Starts a boot thread and registers it under its thread name, recording its rid as running.
     * @param bootThread boot thread
     */
    public static void startBootThread(Thread bootThread) {
        // Start before registering: clearTerminatedThread() treats any non-alive thread
        // (including a not-yet-started one) as terminated and would evict it.
        bootThread.start();
        String name = bootThread.getName();
        addBootThread(name, bootThread);

        String rid = ridOf(name);
        if (rid != null) {
            runningRidSet.add(rid);
        }
    }

    /**
     * Starts boot threads (batch); see {@link #startBootThread(Thread)}.
     * @param bootThreads boot threads
     */
    public static void startBootThread(Thread... bootThreads) {
        for (Thread bootThread : bootThreads) {
            startBootThread(bootThread);
        }
    }

    /**
     * Starts a child thread and registers it under the given boot thread name.
     * @param key boot thread name used as the key
     * @param childThread thread to start
     */
    public static void startChildThread(String key, Thread childThread) {
        childThread.start();
        addChildThread(key, childThread);
    }

    /**
     * Starts child threads (batch); see {@link #startChildThread(String, Thread)}.
     * @param key boot thread name used as the key
     * @param childThreads threads to start
     */
    public static void startChildThread(String key, Thread... childThreads) {
        for (Thread childThread : childThreads) {
            startChildThread(key, childThread);
        }
    }

    /**
     * Registers a boot thread, first purging entries whose boot thread is no longer alive.
     * @param key boot thread name used as the key
     * @param bootThread boot thread
     */
    public static void addBootThread(String key, Thread bootThread) {
        clearTerminatedThread();
        bootThreadsMap.put(key, bootThread);
    }

    /**
     * Registers a child thread under a boot thread name, first purging terminated boot entries.
     * A {@code null} thread is ignored.
     * @param key boot thread name used as the key
     * @param childThread child thread
     */
    public static void addChildThread(String key, Thread childThread) {
        clearTerminatedThread();
        if (childThread == null) {
            return;
        }
        // computeIfAbsent is atomic on ConcurrentHashMap, so no explicit lock is needed
        childSetFor(key).add(childThread);
    }

    /**
     * Registers several child threads under a boot thread name (batch).
     * The given set is copied, so later changes to it are not reflected in the registry;
     * {@code null} elements are skipped.
     * @param key boot thread name used as the key
     * @param set child threads
     */
    public static void addChildThreadBatch(String key, HashSet<Thread> set) {
        clearTerminatedThread();
        Set<Thread> threads = childSetFor(key);
        if (set == null) {
            return;
        }
        for (Thread thread : set) {
            if (thread != null) {
                threads.add(thread);
            }
        }
    }

    /**
     * Unregisters a boot thread that has finished on its own: removes the boot entry (without
     * interrupting it), interrupts and unregisters all of its child threads, and clears its rid.
     * @param key boot thread name used as the key
     */
    public static void clearTerminatedThreadByKey(String key) {
        bootThreadsMap.remove(key);
        Set<Thread> threads = childThreadsMap.remove(key);
        if (threads != null) {
            for (Thread thread : threads) {
                thread.interrupt();
            }
        }
        removeRunningRid(key);
    }

    /**
     * Purges every boot entry whose thread is no longer alive, together with its child-thread
     * entry and running rid.
     * NOTE(review): child threads are only unregistered here, not interrupted; the original
     * comment said unfinished children would be force-stopped — confirm which is intended.
     */
    public static void clearTerminatedThread() {
        logger.debug("clearTerminatedThread");

        // Snapshot the dead boot threads first so the map is not mutated while being streamed
        Map<String, Thread> terminatedThread = bootThreadsMap.entrySet().stream()
                .filter(t -> !t.getValue().isAlive())
                .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));

        terminatedThread.forEach((key, thread) -> {
            // Conditional remove: skip if the name was re-registered with a new thread meanwhile
            if (bootThreadsMap.remove(key, thread)) {
                childThreadsMap.remove(key);
                removeRunningRid(key);
            }
        });
    }

    /**
     * Force-stops a test suite identified by result id (rid) and case id (cid): interrupts the
     * boot thread, flags its {@link RunStepThread} children as stopped, and clears all registries.
     * @param type catalog type (interface test case or performance test script)
     * @param resultId result id
     * @param caseId case id
     * @throws MyException if {@code type} is not recognised
     */
    public static void forceStopSuite(int type, int resultId, int caseId) {
        final String key;
        if (type == CatalogType.INTERFACE_TEST_CASE.getValue()) {
            key = String.format(TesterTaskBootThread.TEST_TASK_BOOT_PRE, resultId, caseId);
        } else if (type == CatalogType.PERFORMANCE_TEST_SCRIPT.getValue()) {
            key = String.format(PerfTesterTaskBootThread.PERF_TEST_TASK_BOOT_PRE, resultId, caseId);
        } else {
            throw new MyException("stopSuite 失败，不识别的 type: " + type);
        }

        // Stop the boot thread
        Thread bootThread = bootThreadsMap.get(key);
        if (bootThread != null) {
            bootThread.interrupt();
        }

        // Clean up the registries
        bootThreadsMap.remove(key);
        Set<Thread> removed = childThreadsMap.remove(key);
        if (!CollectionUtils.isEmpty(removed)) {
            for (Thread thread : removed) {
                if (thread instanceof RunStepThread) {
                    ((RunStepThread) thread).setStopped(true);
                }
            }
        }
        runningTestsMap.remove(resultId + "");

        runningRidSet.remove(RID_PREFIX + resultId);
    }

    /**
     * Stops debug step threads by setting their stopped flag instead of interrupting them,
     * because the current websocket code uses the owning thread for some work and a hard stop
     * causes problems. Non-{@link RunStepThread} children are only unregistered.
     * @param key boot thread name used as the key
     */
    public static void forceStopDebugStepThread(String key) {
        // Remove first so the entry is unregistered even if a flag update fails
        Set<Thread> threads = childThreadsMap.remove(key);
        if (threads == null) {
            return;
        }
        for (Thread thread : threads) {
            if (thread instanceof RunStepThread) {
                ((RunStepThread) thread).setStopped(true);
            }
        }
    }

    /** Returns the (thread-safe) child set for a key, creating it atomically if absent. */
    private static Set<Thread> childSetFor(String key) {
        return childThreadsMap.computeIfAbsent(key, k -> ConcurrentHashMap.<Thread>newKeySet());
    }

    /** Removes the rid recorded for a boot thread name, if one can be extracted. */
    private static void removeRunningRid(String bootThreadName) {
        String rid = ridOf(bootThreadName);
        if (rid != null) {
            runningRidSet.remove(rid);
        }
    }

    /**
     * Extracts the rid segment from a boot thread name, or returns {@code null} (with a warning)
     * when the name has too few {@code '-'}-separated segments.
     * Assumes the rid is the fifth segment, presumably matching TEST_TASK_BOOT_PRE /
     * PERF_TEST_TASK_BOOT_PRE — TODO confirm against those formats.
     */
    private static String ridOf(String bootThreadName) {
        String[] parts = bootThreadName.split("-");
        if (parts.length <= RID_SEGMENT_INDEX) {
            logger.warn("cannot extract rid from boot thread name: {}", bootThreadName);
            return null;
        }
        return parts[RID_SEGMENT_INDEX];
    }
}
